'''
Usage: python annotation.py --junction JUNCTION.txt.gz --output OUTPUT [--threshold N]
'''

import sys
import collections
import copy
import gzip
import argparse

class gene_info:
    '''Plain record holding the span of one gene: TSS, TES and the gene name.'''

    def __init__(self, TSS, TES, gene):
        # Values are stored exactly as given; no conversion or validation.
        self.TSS, self.TES, self.gene = TSS, TES, gene

class junction_info:
    '''
    Plain record describing one junction: two (chromosome, position)
    breakpoints, the SV type and the number of supporting reads.
    '''

    def __init__(self, chrom1, junction1, chrom2, junction2, SV_type, support):
        # First breakpoint.
        self.chrom1, self.junction1 = chrom1, junction1
        # Second breakpoint.
        self.chrom2, self.junction2 = chrom2, junction2
        self.SV_type = SV_type
        # The read count may arrive as text, so it is normalised to int here.
        self.support = int(support)

def load_gene_info(path='/home/ysakam/genome/DB/cds_accGene_sorted_lex.txt'):
    '''
    Load per-gene genomic spans from a tab-separated gene table.

    The columns read (0-based) are: 2 = chromosome, 4 = start (used as TSS),
    5 = end (used as TES) and 12 = gene name. All records sharing a gene name
    on the same chromosome are collapsed into one span running from the
    smallest start to the largest end.

    Records of one gene name located on different chromosomes are kept as
    separate spans. Mixing their coordinates would yield a span that exists
    on neither chromosome.

    Parameters:
        path: location of the gene table. Defaults to the path this script
            has always used, so existing callers are unaffected.

    Returns:
        dict mapping chromosome name -> list of gene_info objects, in order
        of first appearance in the file.

    Raises:
        OSError if the file cannot be opened; IndexError / ValueError if a
        data line has too few columns or non-integer coordinates.
    '''
    # (gene, chrom) -> [min start, max end]; dict order keeps first-seen order.
    spans = dict()
    with open(path) as f:
        for line in f:
            line = line.rstrip('\n')
            # Tolerate blank lines (e.g. a trailing one) and '#' comment lines.
            if not line or line.startswith('#'):
                continue
            items = line.split('\t')

            TSS = int(items[4])
            TES = int(items[5])
            key = (items[12], items[2])
            if key in spans:
                span = spans[key]
                span[0] = min(span[0], TSS)
                span[1] = max(span[1], TES)
            else:
                spans[key] = [TSS, TES]

    out = dict()
    for (gene, chrom), (TSS, TES) in spans.items():
        out.setdefault(chrom, []).append(gene_info(TSS, TES, gene))

    return out

def _gene_at(info, chrom, position):
    '''Return the first gene on chrom whose [TSS, TES] contains position, else '.'.'''
    # .get() so that chromosomes/contigs without any gene entry simply yield
    # no hit instead of raising KeyError.
    for item in info.get(chrom, ()):
        if item.TSS <= position <= item.TES:
            return item.gene
    return '.'


def inter_chrom_annotation(line, info):
    '''
    Annotate one inter-chromosomal (translocation) junction line.

    Parameters:
        line: tab-separated record "chrom1:pos1<TAB>chrom2:pos2<TAB>SV_type<TAB>support".
        info: dict mapping chromosome -> iterable of objects exposing
            TSS, TES and gene attributes.

    Returns:
        (key, gene1, gene2). gene1 / gene2 are the first gene containing the
        respective breakpoint, or '.' when none does. key is the tuple
        (chrom1, pos1, chrom2, pos2, SV_type, support) when at least one
        breakpoint hits a gene, otherwise the empty tuple. Lines whose two
        chromosomes are equal, or that involve chrM / chrEBV, always give
        ((), '.', '.').
    '''
    items = line.rstrip('\n').split('\t')
    first = items[0].split(':')
    second = items[1].split(':')
    chrom1 = first[0]
    chrom2 = second[0]
    junction1 = int(first[1])
    junction2 = int(second[1])
    SV_type = items[2]
    support_read = items[3]

    excluded = ('chrM', 'chrEBV')
    if chrom1 == chrom2 or chrom1 in excluded or chrom2 in excluded:
        return tuple(), '.', '.'

    gene1 = _gene_at(info, chrom1, junction1)
    gene2 = _gene_at(info, chrom2, junction2)

    k = tuple()
    if gene1 != '.' or gene2 != '.':
        k = (chrom1, junction1, chrom2, junction2, SV_type, support_read)

    return k, gene1, gene2

def gene_annotation(ifile, info, annotated_path='junction.annotated.txt'):
    '''
    Annotate every junction in a gzip-compressed junction file with genes.

    Each input line is "chrom1:pos1<TAB>chrom2:pos2<TAB>SV_type<TAB>support".
    Lines whose first chromosome is chrM or chrEBV are ignored.
    Inter-chromosomal lines are delegated to inter_chrom_annotation().
    For intra-chromosomal lines, every gene that contains either breakpoint
    or lies strictly between the two breakpoints is reported.

    Parameters:
        ifile: path of the gzip-compressed junction file.
        info: dict mapping chromosome -> iterable of objects exposing
            TSS, TES and gene attributes.
        annotated_path: where the tab-separated annotation table is written.
            Defaults to 'junction.annotated.txt' in the current directory.

    Returns:
        dict mapping (chrom1, pos1, chrom2, pos2, SV_type, support) -> list
        of gene names. Junctions without any gene are omitted.
    '''
    annotation = dict()
    with gzip.open(ifile, 'rt') as f:
        for line in f:
            # A blank line (e.g. at end of file) carries no junction.
            if not line.strip():
                continue
            items = line.rstrip('\n').split('\t')
            first = items[0].split(':')
            second = items[1].split(':')
            chrom1 = first[0]
            chrom2 = second[0]
            junction1 = int(first[1])
            junction2 = int(second[1])
            SV_type = items[2]
            support_read = items[3]

            if chrom1 in ('chrM', 'chrEBV'):
                continue

            if chrom1 != chrom2:
                k, gene1, gene2 = inter_chrom_annotation(line, info)
                if k:
                    annotation[k] = [gene1, gene2]
                continue

            # Breakpoints may be given in either order.
            left_junction = min(junction1, junction2)
            right_junction = max(junction1, junction2)

            genes = list()
            # .get(): a chromosome/contig without gene entries has no hits
            # rather than raising KeyError.
            for item in info.get(chrom1, ()):
                if item.gene in genes:
                    continue
                if (item.TSS <= left_junction <= item.TES
                        or (left_junction < item.TSS and item.TES < right_junction)
                        or item.TSS <= right_junction <= item.TES):
                    genes.append(item.gene)
            if genes:
                k = (chrom1, junction1, chrom2, junction2, SV_type, support_read)
                annotation[k] = genes

    with open(annotated_path, 'w') as w:
        for k, genes in annotation.items():
            print(*k, ','.join(genes), sep='\t', file=w)

    return annotation


def _is_near(a, b, window=50):
    '''True when a and b lie on the same chromosome pair and both breakpoints differ by less than window.'''
    return (a.chrom1 == b.chrom1 and a.chrom2 == b.chrom2
            and abs(a.junction1 - b.junction1) < window
            and abs(a.junction2 - b.junction2) < window)


def _merge_group(members, window=50):
    '''Collapse the junctions of one (SV type, genes) group into a list of (representative, merged read count).'''
    strong = [m for m in members if m.support > 1]

    if not strong:
        # No junction has more than one read: take the most frequent first
        # breakpoint, then the most frequent second breakpoint seen with it,
        # and add up every junction lying near that pair.
        first_site = collections.Counter(
            (m.chrom1, m.junction1) for m in members).most_common()[0][0]
        partners = [m for m in members if (m.chrom1, m.junction1) == first_site]
        second_site = collections.Counter(
            (m.chrom2, m.junction2) for m in partners).most_common()[0][0]
        rep = next(m for m in partners if (m.chrom2, m.junction2) == second_site)
        total = sum(m.support for m in members if _is_near(rep, m, window))
        return [(rep, total)]

    weak = [m for m in members if m.support <= 1]
    consumed = [False] * len(strong)
    merged = []
    for seed_pos, seed in enumerate(strong):
        if consumed[seed_pos]:
            continue
        # Only junctions not yet assigned to a cluster may join this one, so
        # no junction's reads are counted twice.
        cluster_pos = [p for p, m in enumerate(strong)
                       if not consumed[p] and _is_near(seed, m, window)]
        cluster = [strong[p] for p in cluster_pos]
        for p in cluster_pos:
            consumed[p] = True

        # Representative = best supported junction; max() keeps the first on ties.
        rep = max(cluster, key=lambda m: m.support)
        # Historical behaviour kept on purpose: when the group has exactly two
        # multi-read junctions and they tie, the second one is reported.
        if len(strong) == 2 and len(cluster) == 2 and cluster[0].support == cluster[1].support:
            rep = cluster[1]

        total = sum(m.support for m in cluster)
        # Junctions with at most one read contribute to a cluster only when
        # they lie near its representative.
        total += sum(m.support for m in weak if _is_near(rep, m, window))
        merged.append((rep, total))
    return merged


def junction_merging(annotation, threshold, output):
    '''
    Merge nearby junctions that share SV type and gene annotation, and write
    those with enough supporting reads to a tab-separated file.

    Junctions are grouped by (SV type, tuple of genes). Within a group,
    junctions supported by more than one read are clustered when both
    breakpoints lie less than 50 bp apart on the same chromosomes. The best
    supported junction represents the cluster and the read counts are summed.
    Junctions with at most one read are added to a representative they lie
    near. If a group has no multi-read junction, the most frequent breakpoint
    pair is used as the representative.

    The chromosomes written are those of the representative junction itself,
    and junctions on different chromosomes are never merged.

    Parameters:
        annotation: dict mapping
            (chrom1, pos1, chrom2, pos2, SV_type, support) -> list of genes.
        threshold: minimum merged read count for a junction to be written.
        output: path of the output file; a header line is always written.

    Returns:
        None.
    '''
    # The four known types are seeded to keep their output order. Any other
    # type found in the data is appended instead of raising KeyError.
    junction = {'DEL': dict(), 'INV': dict(), 'TRA': dict(), 'DUP': dict()}
    for key, genes in annotation.items():
        info = junction_info(key[0], key[1], key[2], key[3], key[4], key[5])
        junction.setdefault(info.SV_type, dict()).setdefault(tuple(genes), []).append(info)

    with open(output, 'w') as w:
        print('CHROM1\tPOS1\tCHROM2\tPOS2\tREADS\tTYPE\tGENES', file=w)
        for SVtype, groups in junction.items():
            for genes, members in groups.items():
                for rep, count in _merge_group(members):
                    if count >= threshold:
                        print(rep.chrom1, rep.junction1, rep.chrom2, rep.junction2,
                              count, SVtype, ','.join(genes), sep='\t', file=w)

if __name__ == '__main__':
    # Command-line entry point: load the gene table, annotate the junction
    # file, then merge the annotated junctions into the output file.
    parser = argparse.ArgumentParser()
    # Both paths are mandatory. Marking them required makes argparse reject a
    # missing option immediately with a usage message, instead of passing
    # None on and failing later inside the file-handling code.
    parser.add_argument("--junction", required=True, help="PATH of junction.annotated.txt.gz file")
    parser.add_argument("--output", required=True, help="PATH of output file")
    parser.add_argument("--threshold", type=int, default=5, help="threshold of supporting reads")
    args = parser.parse_args()

    info = load_gene_info()
    annotation = gene_annotation(args.junction, info)
    junction_merging(annotation, args.threshold, args.output)
